import numpy as np
from PIL import Image

import torchvision
import torch
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset


class TransformTwice:
    def __init__(self, transform):
        self.transform = transform

    def __call__(self, inp):
        out1 = self.transform(inp)
        out2 = self.transform(inp)
        return out1, out2


cifar10_mean = (0.4914, 0.4822, 0.4465)  # equals np.mean(train_set.train_data, axis=(0,1,2))/255
cifar10_std = (0.2471, 0.2435, 0.2616)  # equals np.std(train_set.train_data, axis=(0,1,2))/255


def get_cifar10(root,
                label_num,
                batch_size=64,
                unlabel_batch_size=192,
                download=True, ):
    """
    :param root: cifra保存的路径
    :param label_num: 需要视频label的数量
    :param K: 是无标注数据的batch_size = batch_size * K
    :param download: 是否下载，默认是True
    :return:
    """

    transform_train = TransformTwice(transforms.Compose([
        transforms.RandomResizedCrop(32, scale=(0.75, 1.0), ratio=(1.0, 1.0)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ColorJitter(0.1, 0.1, 0.1),
        transforms.ToTensor(),
        transforms.Normalize(cifar10_mean, cifar10_std),
    ]))
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(cifar10_mean, cifar10_std),
    ])
    # 加载原始数据集
    base_dataset = torchvision.datasets.CIFAR10(root, train=True, download=download)
    
    train_labeled_idxs, train_unlabeled_idxs = train_split(base_dataset.targets, int(label_num / 10)) # 区分有标签数据与无标签数据。
    train_labeled_dataset = CIFAR10_labeled(root, train_labeled_idxs, train=True, transform=transform_train) #  有标签数据集
    train_unlabeled_dataset = CIFAR10_unlabeled(root, train_unlabeled_idxs, train=True, transform=transform_train) # 无标签数据集
    test_dataset = CIFAR10_labeled(root, train=False, transform=transform_test, download=True) # 验证集

    print(f"#Labeled: {len(train_labeled_idxs)} #Unlabeled: {len(train_unlabeled_idxs)} ")

    labeled_loader = DataLoader(train_labeled_dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=4)
    unlabeled_loader = DataLoader(train_unlabeled_dataset, batch_size=unlabel_batch_size, shuffle=True, drop_last=True, num_workers=4)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
    return labeled_loader, unlabeled_loader, test_loader


def train_split(labels, n_labeled_per_class):
    """

    :param labels: 全部的标签数据
    :param n_labeled_per_class: 每个标签的数目
    :return: 有标签索引，无标签索引
    """
    labels = np.array(labels)
    train_labeled_idxs = []
    train_unlabeled_idxs = []

    for i in range(10):
        idxs = np.where(labels == i)[0]
        np.random.shuffle(idxs)
        train_labeled_idxs.extend(idxs[:n_labeled_per_class])
        train_unlabeled_idxs.extend(idxs[n_labeled_per_class:])

    np.random.shuffle(train_labeled_idxs)
    np.random.shuffle(train_unlabeled_idxs)

    return train_labeled_idxs, train_unlabeled_idxs


def train_val_split(labels, n_labeled_per_class):
    """

    :param labels: 全部标签数据
    :param n_labeled_per_class: 每个标签的类
    :return:  有标签数据索引，无标签索引，验证集索引
    """
    labels = np.array(labels)
    train_labeled_idxs = []
    train_unlabeled_idxs = []
    val_idxs = []

    for i in range(10):
        idxs = np.where(labels == i)[0]
        np.random.shuffle(idxs)
        train_labeled_idxs.extend(idxs[:n_labeled_per_class])
        train_unlabeled_idxs.extend(idxs[n_labeled_per_class:-500])
        val_idxs.extend(idxs[-500:])
    np.random.shuffle(train_labeled_idxs)
    np.random.shuffle(train_unlabeled_idxs)
    np.random.shuffle(val_idxs)

    return train_labeled_idxs, train_unlabeled_idxs, val_idxs


class CIFAR10_labeled(torchvision.datasets.CIFAR10):

    def __init__(self, root, indexs=None, train=True,
                 transform=None, target_transform=None,
                 download=False):
        super(CIFAR10_labeled, self).__init__(root, train=train,
                                              transform=transform, target_transform=target_transform,
                                              download=download)
        if indexs is not None:
            self.data = self.data[indexs]
            self.targets = np.array(self.targets)[indexs]
        # self.data = transpose(normalize(self.data))

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], self.targets[index]
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target


class CIFAR10_unlabeled(CIFAR10_labeled):

    def __init__(self, root, indexs, train=True,
                 transform=None, target_transform=None,
                 download=False):
        super(CIFAR10_unlabeled, self).__init__(root, indexs, train=train,
                                                transform=transform, target_transform=target_transform,
                                                download=download)
        self.targets = np.array([-1 for i in range(len(self.targets))])




if __name__ == '__main__':
    train_labeled_dataloader, train_unlabeled_dataloader, test_dataloader = get_cifar10("./data", 4000, batch_size=4)
    # label_iter = iter(train_labeled_dataloader)
    # unlabel_iter = iter(train_unlabeled_dataloader)
    # (img1, img2), target_label = next(label_iter)
    # (img1_ul, img2_ul), target_no_label = next(unlabel_iter)
    #
    # input1 = torch.cat([img1, img1_ul])
    # input2 = torch.cat([img2, img2_ul])
    #
    # torchvision.utils.save_image(input1, "1.jpg")
    # torchvision.utils.save_image(input2, "2.jpg")
